# ======================================================================
# Copyright TOTAL / CERFACS / LIRMM (03/2020)
# Contributor: Adrien Suau (<adrien.suau@cerfacs.fr>
#                           <adrien.suau@lirmm.fr>)
#
# This software is governed by the CeCILL-B license under French law and
# abiding  by the  rules of  distribution of free software. You can use,
# modify  and/or  redistribute  the  software  under  the  terms  of the
# CeCILL-B license as circulated by CEA, CNRS and INRIA at the following
# URL "http://www.cecill.info".
#
# As a counterpart to the access to  the source code and rights to copy,
# modify and  redistribute granted  by the  license, users  are provided
# only with a limited warranty and  the software's author, the holder of
# the economic rights,  and the  successive licensors  have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using, modifying and/or  developing or reproducing  the
# software by the user in light of its specific status of free software,
# that  may mean  that it  is complicated  to manipulate,  and that also
# therefore  means that  it is reserved for  developers and  experienced
# professionals having in-depth  computer knowledge. Users are therefore
# encouraged  to load and  test  the software's  suitability as  regards
# their  requirements  in  conditions  enabling  the  security  of their
# systems  and/or  data to be  ensured and,  more generally,  to use and
# operate it in the same conditions as regards security.
#
# The fact that you  are presently reading this  means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
# ======================================================================

import time as timemodule
import logging
import typing as ty

# qat specific
from qat.lang.AQASM.program import Program

from qaths.applications.wave_equation.evolve_1D_dirichlet import (
    compute_r,
    evolve_1d_dirichlet_no_repetition,
)
from qaths.applications.wave_equation.time_adaptation import (
    adapt_dirichlet_evolution_time,
)

# Wave equation quantum solver
from qaths.applications.wave_equation.utils import (
    compute_qubit_number_from_considered_points_1d,
)

# qaths translators
from qaths._cli.data_generation.utils.gate_count import gate_count_to_ibmq

logger = logging.getLogger("qaths._cli.data_generation.utils.circuit_analysers")


def _latest_end_time(end_times):
    """Return the largest value stored in a nested collection of end times.

    :param end_times: an iterable of iterables of comparable values.
        NOTE(review): this is assumed to be the shape of the third value
        returned by ``qat.noisy.util.start_and_end_times`` -- TODO confirm.
    :return: the maximum over *all* the inner values, or 0 when there is no value
        at all (e.g. empty circuit).
    """
    # Flatten before taking the maximum: comparing the inner sequences themselves
    # would order them lexicographically and could miss the real maximum.
    return max((t for group in end_times for t in group), default=0)


try:
    from qat.noisy.util import start_and_end_times

    # Durations of the IBMQ basis gates. The caller stores the resulting depth under
    # a "_ns" key, so these are presumably nanoseconds -- TODO confirm the origin of
    # the 100 / 20 / 347 constants.
    _gt = {
        "U1": 0,
        "U2": 100 + 20,
        "U3": 100 + 20 + 100 + 20,
        "CX": 100 + 20 + 347 + 20 + 100 + 20 + 347 + 20,
    }
    # Duration of each gate name, expressed as a combination of the basis gate
    # durations above.
    ibmq_gate_times = {
        "X": _gt["U3"],
        "Y": _gt["U3"],
        "Z": _gt["U1"],
        "H": _gt["U2"],
        "S": _gt["U1"],
        "T": _gt["U1"],
        "RX": _gt["U3"],
        "RY": _gt["U3"],
        "RZ": _gt["U1"],
        "CNOT": _gt["CX"],
        "C-PH": 3 * _gt["U1"] + 2 * _gt["CX"],
        "CCNOT": 7 * _gt["U1"] + 2 * _gt["U2"] + 6 * _gt["CX"],
        # A C-C-PH is a controlled-C-PH, so according to the definition of C-PH we have:
        #  - 3 C-U1 which are in fact 3 C-PH, i.e. {"U1": 3*3, "CX": 2*3}
        #  - 2 CCNOT, i.e. {"U1": 7*2, "U2": 2*2, "CX": 6*2}
        # Adding the 2 dicts gives us:
        "C-C-PH": 23 * _gt["U1"] + 4 * _gt["U2"] + 18 * _gt["CX"],
        # See https://quantumcomputing.stackexchange.com/q/9444/1386
        "C-H": 2 * _gt["U3"] + _gt["CX"],
    }

    # A small "hack" of a function that creates a space time diagram of the quantum
    # circuit using a map of gate times
    def _get_melbourne_depth(circuit):
        """Return the time-weighted depth of ``circuit``.

        The gate durations listed in ``ibmq_gate_times`` are forwarded to
        ``start_and_end_times`` and the latest end time it reports is returned.

        :param circuit: the circuit to analyse.
        :return: the latest end time over the whole circuit (0 if there is none).
        """
        _, _start_times, end_times = start_and_end_times(circuit, ibmq_gate_times)
        return _latest_end_time(end_times)

except ImportError:
    logger.warning(
        "qat.noisy.util not found. You are probably using MyQLM that does not include "
        "this feature. The Melbourne timings and depths will be set to 0."
    )

    def _get_melbourne_depth(circuit):
        """Fallback used when ``qat.noisy.util`` is unavailable: always return 0."""
        return 0


def get_statistics(
    discretisation_points_number: int,
    time: float,
    epsilon: float,
    trotter_order: int,
    linked_implementations: ty.Optional[ty.List] = None,
    is_wave_equation_solver: bool = False,
) -> dict:
    """Compute and return some statistics on the 1-D wave equation solver.

    Only one repetition of the evolution routine is generated; the gate counts and
    the Melbourne execution time are then multiplied by the number of repetitions
    returned by ``compute_r``.

    The returned dictionary contains, in insertion order:

        - ``"params"``: the tuple ``(discretisation_points_number, time, epsilon,
          trotter_order)``.
        - ``"Melbourne_execution_time_ns"``: value of ``_get_melbourne_depth`` for
          the generated circuit, times the number of repetitions (0 when
          ``qat.noisy.util`` is not available).
        - ``"arity"``: qubit count (``nbqbits``) of the generated circuit.
        - ``"Generation_time_ms"``: time spent generating the QRoutine.
        - ``"Application_time_ms"``: time spent applying it to the program.
        - ``"To_circuit_time_ms"``: time spent in ``Program.to_circ``.
        - ``"circ_gate_count"``: mapping gate name -> number of occurrences.
        - ``"ibmq_circ_gate_count"``: the previous mapping translated by
          ``gate_count_to_ibmq``.

    :param discretisation_points_number: the number of discretisation points used to
        solve the wave equation.
    :param time: the time for which we want to solve the wave equation.
    :param epsilon: the precision of the desired solution.
    :param trotter_order: the order of the Trotter-Suzuki formula to use.
    :param linked_implementations: forwarded to the "link" kwarg of the to_circ method.
    :param is_wave_equation_solver: boolean flag that should be set to True if the
        caller wants to get the data for the wave equation solver. The default is to
        return the data for the Hamiltonian simulation procedure (no time adaptation).
    :return: a dictionary with the statistics.
    """
    res = {"params": (discretisation_points_number, time, epsilon, trotter_order)}

    ###################################
    # 1. Generate the quantum circuit #
    ###################################
    # 1. Compute the number of qubits we will need.
    # NOTE(review): the "- 2" presumably discards the two boundary points of the
    # Dirichlet problem -- confirm against the helper.
    n = compute_qubit_number_from_considered_points_1d(discretisation_points_number - 2)

    simulation_time = time
    if is_wave_equation_solver:
        simulation_time = adapt_dirichlet_evolution_time(
            time, discretisation_points_number
        )
    repetition_number = compute_r(simulation_time, epsilon, 2, 1.0, trotter_order)

    # 2. Create the QRoutine solving the wave equation and time it.
    # perf_counter() is used for every timing below: it is monotonic and has the
    # highest available resolution, unlike the wall clock returned by time().
    start_gen = timemodule.perf_counter()
    solve_gate_no_repeat = evolve_1d_dirichlet_no_repetition(
        simulation_time / repetition_number, discretisation_points_number, trotter_order
    )
    end_gen = timemodule.perf_counter()

    # 3. Initialise the quantum program and allocate the needed qubits.
    prog = Program()
    x = prog.qalloc(n)
    ancilla = prog.qalloc(solve_gate_no_repeat.arity - n)

    # 4. Apply the QRoutine and time it.
    start_apply = timemodule.perf_counter()
    prog.apply(solve_gate_no_repeat, x, ancilla)
    end_apply = timemodule.perf_counter()

    # 5. Translate the quantum program to a Circuit and time it.
    start_circ = timemodule.perf_counter()
    circ = prog.to_circ(link=linked_implementations, inline=True)
    end_circ = timemodule.perf_counter()

    ###################################
    # 2.        Optimisations         #
    ###################################
    # NOTE: no optimisation pass is applied for the moment. Neither the Graphopt
    # expansion/optimisation nor the explicit expansion of the circuit to the IBMQ
    # basis is performed: the IBMQ figures below are derived from the gate counts.

    #####################################
    # 3. Extraction of other statistics #
    #####################################
    # 1. Compute the gate count.
    # We add "repetition_number" for each gate because the circuit is repeated that
    # many times.
    circ_gate_count: ty.Dict[str, int] = {}
    for op_name, _parameters, _qubits in circ.iterate_simple():
        circ_gate_count[op_name] = circ_gate_count.get(op_name, 0) + repetition_number
    ibmq_circ_no_ctrl_gate_count = gate_count_to_ibmq(circ_gate_count)

    # 2. Compute the running time on the real quantum hardware Melbourne.
    # _get_melbourne_depth is always defined at import time (a fallback returning 0
    # is installed when qat.noisy.util is missing), so no guard is needed here.
    res["Melbourne_execution_time_ns"] = _get_melbourne_depth(circ) * repetition_number

    ###################################
    # 4.     Statistics gathering     #
    ###################################
    # Arity
    res["arity"] = circ.nbqbits
    # Timings (seconds -> milliseconds)
    res["Generation_time_ms"] = (end_gen - start_gen) * 10 ** 3
    res["Application_time_ms"] = (end_apply - start_apply) * 10 ** 3
    res["To_circuit_time_ms"] = (end_circ - start_circ) * 10 ** 3
    # Gate count
    res["circ_gate_count"] = circ_gate_count
    res["ibmq_circ_gate_count"] = ibmq_circ_no_ctrl_gate_count

    return res


def get_stats_from_tuple(tup):
    """Call ``get_statistics`` with the positional arguments packed in ``tup``.

    Presumably meant for APIs that hand a single argument to their callable
    (e.g. a pool ``map``) -- verify against the callers.

    :param tup: an iterable holding the positional arguments of ``get_statistics``,
        in the order of its signature.
    :return: the dictionary returned by ``get_statistics``.
    """
    positional_arguments = tuple(tup)
    return get_statistics(*positional_arguments)
